Pinvon's Blog

Things seen, heard, thought and imagined

pytest

Quick start

Installation

pip install pytest

Running tests

Passing

def test_passing():
    assert (1, 2, 3) == (1, 2, 3)

$ pytest test.py
...
test.py .

A . means the test passed.

Failing

def test_failing():
    assert (1, 2, 3) == (3, 2, 1)

$ pytest test.py
...
test.py F
...

An F means the test failed.

Introduction

Mock class diagram

[Figure: Mock class diagram]

Mock parameters

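Parameter name

name gives the mock object a name that identifies it. The name appears in the mock's repr, which helps when debugging.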

from mock import Mock

# create the mock object
mockFoo = Mock(name = "Foo")

print mockFoo
# returns: <Mock name='Foo' id='494864'>
print repr(mockFoo)
# still returns: <Mock name='Foo' id='494864'>
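Parameter spec

spec defines which attributes the mock object has. It can be a list of attribute names. Accessing an attribute in the list returns a child mock, and accessing any other attribute raises AttributeError.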

from mock import Mock

# prepare the spec list
fooSpec = ["_fooValue", "callFoo", "doFoo"]

# create the mock object
mockFoo = Mock(spec = fooSpec)

# accessing the mocked attributes
print mockFoo
# <Mock id='427280'>
print mockFoo._fooValue
# returns <Mock name='mock._fooValue' id='2788112'>
print mockFoo.callFoo()
# returns: <Mock name='mock.callFoo()' id='2815376'>

mockFoo.callFoo()
# nothing happens, which is fine

# accessing the missing attributes
print mockFoo._fooBar
# raises: AttributeError: Mock object has no attribute '_fooBar'
mockFoo.callFoobar()
# raises: AttributeError: Mock object has no attribute 'callFoobar'

spec can also be a class. The mock then has the same attributes as that class.

from mock import Mock

# The class interfaces
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue    

# create the mock object
mockFoo = Mock(spec = Foo)

# accessing the mocked attributes
print mockFoo
# returns <Mock spec='Foo' id='507120'>
print mockFoo._fooValue
# returns <Mock name='mock._fooValue' id='2788112'>
print mockFoo.callFoo()
# returns: <Mock name='mock.callFoo()' id='2815376'>

mockFoo.callFoo()
# nothing happens, which is fine

# accessing the missing attributes
print mockFoo._fooBar
# raises: AttributeError: Mock object has no attribute '_fooBar'
mockFoo.callFoobar()
# raises: AttributeError: Mock object has no attribute 'callFoobar'
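For comparison, here is a minimal Python 3 sketch of the same idea. It assumes the standard-library unittest.mock (the mock package became part of the standard library as unittest.mock in Python 3.3) and uses a trimmed-down Foo. A class spec limits attribute access, and it also makes isinstance() checks against that class pass:

```python
from unittest.mock import Mock


class Foo:
    _fooValue = 123

    def callFoo(self):
        pass

    def doFoo(self, argValue):
        pass


mockFoo = Mock(spec=Foo)

# Attributes defined on Foo are available as child mocks.
mockFoo.callFoo()
print(isinstance(mockFoo, Foo))  # True: a class spec also passes isinstance() checks

# Attributes that Foo does not define raise AttributeError.
try:
    mockFoo.callFoobar()
except AttributeError as exc:
    print("AttributeError:", exc)
```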
Parameter return_value

return_value sets the value that is returned when the mock object itself is called.

from mock import Mock

# create the mock object
mockFoo = Mock(return_value = 456)

print mockFoo
# <Mock id='2787568'>

mockObj = mockFoo()
print mockObj
# returns: 456
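
return_value can also be a real object, such as an instance of a class. Calling the mock then returns that instance:

from mock import Mock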

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

# creating the mock object
fooObj = Foo()
print fooObj
# returns: <__main__.Foo object at 0x68550>

mockFoo = Mock(return_value = fooObj)
print mockFoo
# returns: <Mock id='2788144'>

# creating an "instance"
mockObj = mockFoo()
print mockObj
# returns: <__main__.Foo object at 0x68550>

# working with the mocked instance
print mockObj._fooValue
# returns: 123
mockObj.callFoo()
# returns: Foo:callFoo_
mockObj.doFoo("narf")
# returns: Foo:doFoo:input =  narf
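The following Python 3 sketch assumes unittest.mock. It shows that return_value can also be changed after the mock is created, and that it can be set on a child method. The Warehouse tests later in this post configure their mocks the second way. The name mockSource is only illustrative:

```python
from unittest.mock import Mock

mockFoo = Mock(return_value=456)
print(mockFoo())  # 456

# return_value can be reassigned after the mock is created ...
mockFoo.return_value = "narf"
print(mockFoo())  # narf

# ... and configured on a child method of another mock.
mockSource = Mock()
mockSource.hasInventory.return_value = True
print(mockSource.hasInventory("mushrooms"))  # True
```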
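Parameter side_effect

If side_effect is set, it takes precedence over return_value when the mock is called. If side_effect is an exception, the call raises it. If it is a function, the call returns that function's result instead of return_value.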

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

# creating the mock object (without a side effect)
fooObj = Foo()

mockFoo = Mock(return_value = fooObj)
print mockFoo
# returns: <Mock id='2788144'>

# creating an "instance"
mockObj = mockFoo()
print mockObj
# returns: <__main__.Foo object at 0x2a88f0>

# creating a mock object (with a side effect)

mockFoo = Mock(return_value = fooObj, side_effect = StandardError)
mockObj = mockFoo()
# raises: StandardError
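Here is a Python 3 sketch of the two non-iterable kinds of side_effect, using unittest.mock. StandardError exists only in Python 2, so ValueError stands in for it. The lambda is an illustrative stand-in for any function:

```python
from unittest.mock import Mock

# An exception (class or instance) as side_effect: calling the mock raises it.
mockFoo = Mock(return_value=123, side_effect=ValueError("boom"))
try:
    mockFoo()
except ValueError as exc:
    print("raised:", exc)  # raised: boom

# A function as side_effect: it receives the call's arguments,
# and its result is returned instead of return_value.
mockDouble = Mock(return_value=123, side_effect=lambda x: x * 2)
print(mockDouble(21))  # 42
```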

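side_effect can also be an iterable. Each call returns the next value from it. Once the iterable is exhausted, the next call raises StopIteration.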

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

# creating the mock object (with a side effect)
fooObj = Foo()

fooList = [665, 666, 667]
mockFoo = Mock(return_value = fooObj, side_effect = fooList)

fooTest = mockFoo()
print fooTest
# returns 665

fooTest = mockFoo()
print fooTest
# returns 666

fooTest = mockFoo()
print fooTest
# returns 667

fooTest = mockFoo()
print fooTest
# raises: StopIteration
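The same behaviour in Python 3, assuming unittest.mock. An exception placed inside the iterable is raised on the matching call instead of being returned. This can simulate a flaky dependency, for example (mockFlaky is a hypothetical name):

```python
from unittest.mock import Mock

mockFoo = Mock(side_effect=[665, 666, 667])
print([mockFoo() for _ in range(3)])  # [665, 666, 667]

try:
    mockFoo()  # the iterable is exhausted
except StopIteration:
    print("StopIteration")

# The first call raises, the second returns normally.
mockFlaky = Mock(side_effect=[OSError("timeout"), "ok"])
try:
    mockFlaky()
except OSError as exc:
    print("first call raised:", exc)
print(mockFlaky())  # ok
```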

Assertions

assert_called_with()

Checks that the most recent call to the mock method was made with the given arguments.

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        pass
    
    def doFoo(self, argValue):
        pass

# create the mock object
mockFoo = Mock(spec = Foo)
print mockFoo
# returns <Mock spec='Foo' id='507120'>

mockFoo.doFoo("narf")
mockFoo.doFoo.assert_called_with("narf")
# assertion passes

mockFoo.doFoo("zort")
mockFoo.doFoo.assert_called_with("narf")
# AssertionError: Expected call: doFoo('narf')
# Actual call: doFoo('zort')
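A minimal Python 3 sketch, using unittest.mock, that makes the "last call only" rule explicit. After two calls, only the second one satisfies assert_called_with():

```python
from unittest.mock import Mock

mockFoo = Mock()
mockFoo.doFoo("narf")
mockFoo.doFoo("zort")

# Only the most recent call is compared, so this passes ...
mockFoo.doFoo.assert_called_with("zort")

# ... while the earlier call does not count.
try:
    mockFoo.doFoo.assert_called_with("narf")
except AssertionError:
    print("narf was not the last call")
```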
assert_called_once_with()

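Performs the same check as assert_called_with(), but also requires that the mock was called exactly once. If it was called more than once, the assertion fails.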

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        pass
    
    def doFoo(self, argValue):
        pass

# create the mock object
mockFoo = Mock(spec = Foo)
print mockFoo
# returns <Mock spec='Foo' id='507120'>

mockFoo.callFoo()
mockFoo.callFoo.assert_called_once_with()
# assertion passes

mockFoo.callFoo()
mockFoo.callFoo.assert_called_once_with()
# AssertionError: Expected to be called once. Called 2 times.
assert_any_call()

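Checks whether the mock has ever been called with the given arguments. Any call in its history counts, not just the most recent one.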

from mock import Mock

# The mock specification
class Foo(object):
    _fooValue = 123
    
    def callFoo(self):
        pass
    
    def doFoo(self, argValue):
        pass

# create the mock object
mockFoo = Mock(spec = Foo)
print mockFoo
# returns <Mock spec='Foo' id='507120'>

mockFoo.callFoo()
mockFoo.doFoo("narf")
mockFoo.doFoo("zort")

mockFoo.callFoo.assert_any_call()
# assert passes

mockFoo.callFoo()
mockFoo.doFoo("troz")

mockFoo.doFoo.assert_any_call("zort")
# assert passes

mockFoo.doFoo.assert_any_call("egad")
# raises: AssertionError: doFoo('egad') call not found
  • mockFoo.callFoo.assert_any_call() passes, because callFoo() was indeed called
  • mockFoo.doFoo.assert_any_call("zort") passes, because doFoo("zort") was indeed called
  • mockFoo.doFoo.assert_any_call("egad") fails, because doFoo("egad") was never called
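To contrast assert_any_call() with assert_called_with(), here is a small Python 3 sketch using unittest.mock. "zort" is found in the call history, but it is not the last call:

```python
from unittest.mock import Mock

mockFoo = Mock()
mockFoo.doFoo("narf")
mockFoo.doFoo("zort")
mockFoo.doFoo("troz")

mockFoo.doFoo.assert_any_call("zort")  # passes: zort is somewhere in the history

try:
    mockFoo.doFoo.assert_called_with("zort")  # fails: the last call was troz
except AssertionError:
    print("assert_called_with() only looks at the last call")
```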
assert_has_calls()

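Checks that the given calls appear in mock_calls, in the given order. With the optional argument any_order=True, the order does not matter.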

from mock import Mock, call

# The mock specification
class Foo(object):
    _fooValue = 123
    
    def callFoo(self):
        pass
    
    def doFoo(self, argValue):
        pass

# create the mock object
mockFoo = Mock(spec = Foo)
print mockFoo
# returns <Mock spec='Foo' id='507120'>

mockFoo.callFoo()
mockFoo.doFoo("narf")
mockFoo.doFoo("zort")

fooCalls = [call.callFoo(), call.doFoo("narf"), call.doFoo("zort")]
mockFoo.assert_has_calls(fooCalls)
# assert passes

fooCalls = [call.callFoo(), call.doFoo("zort"), call.doFoo("narf")]
mockFoo.assert_has_calls(fooCalls)
# AssertionError: Calls not found.
# Expected: [call.callFoo(), call.doFoo('zort'), call.doFoo('narf')]
# Actual: [call.callFoo(), call.doFoo('narf'), call.doFoo('zort')]

fooCalls = [call.callFoo(), call.doFoo("zort"), call.doFoo("narf")]
mockFoo.assert_has_calls(fooCalls, any_order = True)
# assert passes
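Without any_order, the listed calls must appear one after another in mock_calls. Extra calls before or after them are allowed. A Python 3 sketch using unittest.mock:

```python
from unittest.mock import Mock, call

mockFoo = Mock()
mockFoo.callFoo()
mockFoo.doFoo("narf")
mockFoo.doFoo("zort")
mockFoo.doFoo("troz")

# Passes: the two calls are consecutive, and the surrounding calls are ignored.
mockFoo.assert_has_calls([call.doFoo("narf"), call.doFoo("zort")])

# Fails: doFoo("narf") sits between these two calls.
try:
    mockFoo.assert_has_calls([call.callFoo(), call.doFoo("zort")])
except AssertionError:
    print("calls are not consecutive")

# Passes: any_order=True only checks that each call happened.
mockFoo.assert_has_calls([call.doFoo("troz"), call.callFoo()], any_order=True)
```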

Managing mocks

attach_mock()

attach_mock() attaches a second mock object to the first one as an attribute, under a new name.

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

class Bar(object):
    # instance properties
    _barValue = 456
    
    def callBar(self):
        pass
    
    def doBar(self, argValue):
        pass

# create the first mock object
mockFoo = Mock(spec = Foo)
print mockFoo
# returns <Mock spec='Foo' id='507120'>

# create the second mock object
mockBar = Mock(spec = Bar)
print mockBar
# returns: <Mock spec='Bar' id='2784400'>

# attach the second mock to the first
mockFoo.attach_mock(mockBar, 'fooBar')

# access the first mock's attributes
print mockFoo
# returns: <Mock spec='Foo' id='495312'>
print mockFoo._fooValue
# returns: <Mock name='mock._fooValue' id='428976'>
print mockFoo.callFoo()
# returns: <Mock name='mock.callFoo()' id='448144'>

# access the second mock and its attributes
print mockFoo.fooBar
# returns: <Mock name='mock.fooBar' spec='Bar' id='2788592'>
print mockFoo.fooBar._barValue
# returns: <Mock name='mock.fooBar._barValue' id='2788016'>
print mockFoo.fooBar.callBar()
# returns: <Mock name='mock.fooBar.callBar()' id='2819344'>
print mockFoo.fooBar.doBar("narf")
# returns: <Mock name='mock.fooBar.doBar()' id='4544528'>
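attach_mock() has a useful consequence: calls made on the attached mock are also recorded in the parent's mock_calls. The order of calls across both mocks can therefore be checked in one place. This Python 3 sketch assumes unittest.mock and plain Mock objects without specs:

```python
from unittest.mock import Mock, call

mockFoo = Mock()
mockBar = Mock()
mockFoo.attach_mock(mockBar, "fooBar")

mockFoo.callFoo()
mockBar.doBar("narf")      # called through the original name
mockFoo.fooBar.callBar()   # or through the parent

# Both calls on the attached mock show up in the parent's history.
print(mockFoo.mock_calls)
# [call.callFoo(), call.fooBar.doBar('narf'), call.fooBar.callBar()]
```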
configure_mock()

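configure_mock() modifies the mock object after it has been created. Dotted keys such as 'callFoo.return_value' configure child mocks.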

from mock import Mock

class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

mockFoo = Mock(spec = Foo, return_value = 555)
print mockFoo()
# returns: 555

mockFoo.configure_mock(return_value = 999)
print mockFoo()
# returns: 999

fooSpec = {'callFoo.return_value':"narf", 'doFoo.return_value':"zort", 'doFoo.side_effect':StandardError}
mockFoo.configure_mock(**fooSpec)

print mockFoo.callFoo()
# returns: narf
print mockFoo.doFoo("narf")
# raises: StandardError

fooSpec = {'doFoo.side_effect':None}
mockFoo.configure_mock(**fooSpec)
print mockFoo.doFoo("narf")
# returns: zort
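  • The first mockFoo.doFoo("narf") fails because doFoo.side_effect was set to StandardError. Once doFoo.side_effect is reset to None, the call returns doFoo.return_value, "zort".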
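The same configure_mock() steps as a Python 3 sketch, assuming unittest.mock. ValueError replaces StandardError, which exists only in Python 2:

```python
from unittest.mock import Mock

mockFoo = Mock(return_value=555)

# Dotted keys configure child mocks.
mockFoo.configure_mock(**{
    "return_value": 999,
    "callFoo.return_value": "narf",
    "doFoo.return_value": "zort",
    "doFoo.side_effect": ValueError,
})
print(mockFoo(), mockFoo.callFoo())  # 999 narf

try:
    mockFoo.doFoo("narf")
except ValueError:
    print("doFoo raised ValueError")

# Clearing side_effect lets return_value through again.
mockFoo.configure_mock(**{"doFoo.side_effect": None})
print(mockFoo.doFoo("narf"))  # zort
```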
mock_add_spec()

mock_add_spec() replaces the mock object's spec, which changes the set of attributes it allows.

from mock import Mock

# The class interfaces
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

class Bar(object):
    # instance properties
    _barValue = 456
    
    def callBar(self):
        pass
    
    def doBar(self, argValue):
        pass
    
# create the mock object
mockFoo = Mock(spec = Foo)

print mockFoo
# returns <Mock spec='Foo' id='507120'>
print mockFoo._fooValue
# returns <Mock name='mock._fooValue' id='2788112'>
print mockFoo.callFoo()
# returns: <Mock name='mock.callFoo()' id='2815376'>

# add a new spec attributes
mockFoo.mock_add_spec(Bar)

print mockFoo
# returns: <Mock spec='Bar' id='491088'>
print mockFoo._barValue
# returns: <Mock name='mock._barValue' id='2815120'>
print mockFoo.callBar()
# returns: <Mock name='mock.callBar()' id='4544368'>

print mockFoo._fooValue
# raises: AttributeError: Mock object has no attribute '_fooValue'
print mockFoo.callFoo()
# raises: AttributeError: Mock object has no attribute 'callFoo'

After mockFoo's spec is changed to Bar, accessing any attribute that belongs only to Foo raises an error.
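This trimmed Python 3 sketch checks the effect of mock_add_spec() with hasattr(). It assumes unittest.mock and uses one-method versions of Foo and Bar:

```python
from unittest.mock import Mock


class Foo:
    def callFoo(self):
        pass


class Bar:
    def callBar(self):
        pass


mockFoo = Mock(spec=Foo)
mockFoo.mock_add_spec(Bar)  # the new spec replaces the old one

print(hasattr(mockFoo, "callBar"))  # True
print(hasattr(mockFoo, "callFoo"))  # False: Foo's attributes are no longer allowed
```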

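reset_mock()

reset_mock() returns the mock object to its pre-test state. It clears the call records (called, call_count, call_args, mock_calls and so on) of the mock and of its child mocks. It does not clear return_value, side_effect, or attributes configured on child methods (for example, mockFoo.callFoo.return_value is an attribute of the callFoo method).

To clear side_effect, set it explicitly, e.g. mockFoo.callFoo.side_effect = None.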
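A Python 3 sketch, using unittest.mock, of what reset_mock() keeps and what it clears. Since Python 3.6, reset_mock() also accepts return_value=True and side_effect=True to clear those values too. The sketch uses the default behaviour:

```python
from unittest.mock import Mock

mockFoo = Mock()
mockFoo.callFoo.return_value = "narf"
mockFoo.callFoo.side_effect = ValueError

try:
    mockFoo.callFoo()
except ValueError:
    pass
print(mockFoo.callFoo.call_count)   # 1

mockFoo.reset_mock()
print(mockFoo.callFoo.call_count)   # 0: call history is cleared, children included
print(mockFoo.callFoo.side_effect)  # <class 'ValueError'>: configuration is kept

mockFoo.callFoo.side_effect = None  # clear side_effect by hand
print(mockFoo.callFoo())            # narf
```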

Call-tracking attributes

called

called indicates whether the mock object itself has been called.

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

# create the first mock object
mockFoo = Mock(spec = Foo)
print mockFoo
# returns <Mock spec='Foo' id='507120'>

print mockFoo.called
# returns: False

mockFoo()
print mockFoo.called
# returns: True

mockFoo = Mock(spec = Foo)
print mockFoo.called
# returns: False

mockFoo.callFoo()
print mockFoo.called
# returns: False

As the code shows, calling mockFoo.callFoo() is not the same as calling mockFoo, so mockFoo.called stays False.

call_count

call_count counts how many times the mock object has been called.

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

# create the first mock object
mockFoo = Mock(spec = Foo)
print mockFoo
# returns <Mock spec='Foo' id='507120'>

print mockFoo.call_count
# returns: 0

mockFoo()
print mockFoo.call_count
# returns: 1

mockFoo.callFoo()
print mockFoo.call_count
# returns: 1
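The following Python 3 sketch shows called and call_count side by side, assuming unittest.mock. Calls to a child method are tracked on the child mock, not on the parent:

```python
from unittest.mock import Mock

mockFoo = Mock()
mockFoo.callFoo()
mockFoo.callFoo()

print(mockFoo.called, mockFoo.call_count)                  # False 0
print(mockFoo.callFoo.called, mockFoo.callFoo.call_count)  # True 2

mockFoo()
print(mockFoo.called, mockFoo.call_count)                  # True 1
```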
call_args

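call_args holds the arguments of the most recent call to the mock object (None if it has not been called yet).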

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

# create the first mock object
mockFoo = Mock(spec = Foo, return_value = "narf")
print mockFoo
# returns <Mock spec='Foo' id='507120'>
print mockFoo.call_args
# returns: None

mockFoo("zort")
print mockFoo.call_args
# returns: call('zort')

mockFoo()
print mockFoo.call_args
# returns: call()

mockFoo("troz")
print mockFoo.call_args
# returns: call('troz')

mockFoo.callFoo()
print mockFoo.call_args
# returns: call('troz')

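Note that calling the child method mockFoo.callFoo() does not change mockFoo.call_args. It still holds the arguments of the last direct call, call('troz').

call_args_list

call_args_list lists the arguments of every call to the mock object, in order.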
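call_args is a call object. It compares equal to call(...) and unpacks into an (args, kwargs) pair. This Python 3 sketch assumes unittest.mock, and the keyword argument amount is illustrative:

```python
from unittest.mock import Mock, call

mockFoo = Mock()
mockFoo("troz", amount=10)

args, kwargs = mockFoo.call_args
print(args, kwargs)                                  # ('troz',) {'amount': 10}
print(mockFoo.call_args == call("troz", amount=10))  # True
```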

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

# create the first mock object
mockFoo = Mock(spec = Foo, return_value = "narf")
print mockFoo
# returns <Mock spec='Foo' id='507120'>

mockFoo("zort")
print mockFoo.call_args_list
# returns: [call('zort')]

mockFoo()
print mockFoo.call_args_list
# returns: [call('zort'), call()]

mockFoo("troz")
print mockFoo.call_args_list
# returns: [call('zort'), call(), call('troz')]

mockFoo.callFoo()
print mockFoo.call_args_list
# returns: [call('zort'), call(), call('troz')]
method_calls

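method_calls lists the calls made to the mock object's methods (its child attributes). Calls to the mock itself are not included.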

from mock import Mock

# The mock object
class Foo(object):
    # instance properties
    _fooValue = 123
    
    def callFoo(self):
        print "Foo:callFoo_"
    
    def doFoo(self, argValue):
        print "Foo:doFoo:input = ", argValue

# create the first mock object
mockFoo = Mock(spec = Foo, return_value = "poink")
print mockFoo
# returns <Mock spec='Foo' id='507120'>
print mockFoo.method_calls
# returns []

mockFoo()
print mockFoo.method_calls
# returns []

mockFoo.callFoo()
print mockFoo.method_calls
# returns: [call.callFoo()]

mockFoo.doFoo("narf")
print mockFoo.method_calls
# returns: [call.callFoo(), call.doFoo('narf')]

mockFoo()
print mockFoo.method_calls
# returns: [call.callFoo(), call.doFoo('narf')]
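method_calls is easy to confuse with mock_calls, which also records direct calls to the mock. A Python 3 sketch, assuming unittest.mock:

```python
from unittest.mock import Mock, call

mockFoo = Mock()
mockFoo("narf")
mockFoo.callFoo()
mockFoo.doFoo("zort")

print(mockFoo.method_calls)  # [call.callFoo(), call.doFoo('zort')]
print(mockFoo.mock_calls)    # [call('narf'), call.callFoo(), call.doFoo('zort')]
```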

Testing

assert

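assert is the most basic tool. It checks a condition.

Catching exceptions

When testing, you often need to check that an expected exception is actually raised, to confirm that the error handling works. Use pytest.raises() to catch the exception.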

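import pytest

def connect(host, port):  # hypothetical stand-in for the function under test
    if not isinstance(port, int):
        raise TypeError('port type must be int')

def test_raises():
    with pytest.raises(TypeError) as e:
        connect('localhost', '6379')
    assert e.value.args[0] == 'port type must be int'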
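pytest.raises() also takes a match argument. This is a regular expression that is searched for in the string form of the exception, and it shortens the check above. The sketch uses a hypothetical connect() that rejects a non-integer port:

```python
import pytest


def connect(host, port):  # hypothetical function under test
    if not isinstance(port, int):
        raise TypeError('port type must be int')


def test_raises_match():
    # match is searched in str(exception), so escape regex metacharacters if needed
    with pytest.raises(TypeError, match='must be int'):
        connect('localhost', '6379')
```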

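Marking functions

pytest looks in the current directory (and its subdirectories) for .py files whose names start with test_ or end with _test. It runs the functions in them whose names start with test. If a test is not finished yet, you can choose which tests to run in the following ways:

1 Use :: to specify the function name explicitly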

pytest test.py::test_func1

Only test_func1() runs.

2 Use -k for keyword matching

pytest -k func1 test.py

Only the tests in test.py whose names contain func1 run.

3 Mark functions with pytest.mark

import pytest


@pytest.mark.finished
def test_func1():
    assert 1 == 1

@pytest.mark.unfinished
def test_func2():
    assert 1 != 1

When running pytest, use -m to select the marked test functions:

pytest -m finished test.py

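With mark, you can give each function its own marker and then run only the functions that carry a given marker.

There is also pytest.mark.skip. It skips a test without requiring -m when running pytest: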

@pytest.mark.skip(reason='out-of-date api')
def test_connect():
    pass

$ pytest test.py
...
test.py s

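An s means the test was skipped.

Parametrization

If you loop over all the test cases inside a single test function, the first failing case stops the whole test.

pytest.mark.parametrize(argnames, argvalues) runs the test function once per set of parameters, so each set is tested independently.

An example that checks user passwords:

import pytest


@pytest.mark.parametrize('user, passwd',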
                         [('jack', 'abcdefgh'),
                          ('tom', 'a123456a')])
def test_passwd_md5(user, passwd):
    db = {
        'jack': 'e8dc4081b13434b45189a720b77b6818',
        'tom': '1702a132e769a623c1adb78353fc9503'
    }

    import hashlib

    assert hashlib.md5(passwd.encode()).hexdigest() == db[user]

$ pytest -v test.py
...
collected 2 items

tests/test-function/test_parametrize.py::test_passwd_md5[jack-abcdefgh] PASSED [ 50%]
tests/test-function/test_parametrize.py::test_passwd_md5[tom-a123456a] PASSED [100%]

========================== 2 passed in 0.04 seconds ===========================

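The -v option lists each parameter set as a separate test item, as shown above.

Fixtures

Concept: fixtures are functions that pytest runs automatically before (and, if needed, after) the test functions that use them.

Use fixtures for setup and cleanup work.

Fixtures can be defined directly in each test script, but it is better to collect them in conftest.py.

Setup and teardown

pytest uses yield to split a fixture into two parts. The code before yield is setup, and the code after it is teardown.

import pytest


@pytest.fixture()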
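When the auto-generated IDs are hard to read, the ids argument of parametrize gives each parameter set its own label in the -v output. A sketch with a hypothetical add() function under test:

```python
import pytest


def add(a, b):  # hypothetical function under test
    return a + b


@pytest.mark.parametrize('a, b, expected',
                         [(1, 2, 3),
                          (0, 0, 0),
                          (-1, 1, 0)],
                         ids=['positive', 'zeros', 'mixed-sign'])
def test_add(a, b, expected):
    assert add(a, b) == expected
```

With -v, the items are then listed as test_add[positive], test_add[zeros] and test_add[mixed-sign].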
def db():
    print('Connection successful')

    yield

    print('Connection closed')


def search_user(user_id):
    d = {
        '001': 'xiaoming'
    }
    return d[user_id]


def test_search(db):
    assert search_user('001') == 'xiaoming'

$ pytest -s test.py
============================= test session starts =============================
platform win32 -- Python 3.6.4, pytest-3.6.1, py-1.5.2, pluggy-0.6.0
rootdir: F:\self-repo\learning-pytest, inifile:
collected 1 item

tests\fixture\test_db.py Connection successful
.Connection closed


========================== 1 passed in 0.02 seconds ===========================

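The -s option stops pytest from capturing output, so the print messages are shown instead of being swallowed.

A test requests a fixture by naming the @pytest.fixture-decorated function as one of its parameters. For example, test_search() must take a db parameter.

Scope

A fixture's scope parameter declares its scope. The options are:

  • function: function level; the fixture runs once for every test function;
  • class: class level; it runs once per test class;
  • module: module level; it runs once per module;
  • session: session level; it runs once per test session.

import pytest


@pytest.fixture(scope='function')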
def func_scope():
    pass


@pytest.fixture(scope='module')
def mod_scope():
    pass


@pytest.fixture(scope='session')
def sess_scope():
    pass


@pytest.fixture(scope='class')
def class_scope():
    pass

Usage: request the fixtures as parameters of the test function, e.g.:

def test_multi_scope(sess_scope, mod_scope, func_scope):
    pass

For the class scope, the fixture is typically applied to the whole class with pytest.mark.usefixtures:

@pytest.mark.usefixtures('class_scope')
class TestClassScope:
    def test_1(self):
        pass

    def test_2(self):
        pass

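Autouse fixtures

To make a fixture run automatically, use the autouse parameter.

Below are two automatic timing fixtures. One measures the run time of each test function (function scope), and the other measures the total time of the test session (session scope).

# test_autouse.py

import time

import pytest

DATE_FORMAT = '%Y-%m-%d %H:%M:%S'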


@pytest.fixture(scope='session', autouse=True)
def timer_session_scope():
    start = time.time()
    print('\nstart: {}'.format(time.strftime(DATE_FORMAT, time.localtime(start))))

    yield

    finished = time.time()
    print('finished: {}'.format(time.strftime(DATE_FORMAT, time.localtime(finished))))
    print('Total time cost: {:.3f}s'.format(finished - start))


@pytest.fixture(autouse=True)
def timer_function_scope():
    start = time.time()
    yield
    print(' Time cost: {:.3f}s'.format(time.time() - start))


def test_1():
    time.sleep(1)


def test_2():
    time.sleep(2)

Note that test_1() and test_2() take no parameters, yet both timer_session_scope() and timer_function_scope() run.

conftest.py

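If several test files need the same functionality (such as logging in), you can put it in a conftest.py file. Everything else is used the same way. Keep the following in mind when using conftest.py:

  • the file name conftest.py is fixed and cannot be changed;
  • conftest.py sits in the same directory as the test files that use it (here, a package with an __init__.py file);
  • there is no need to import conftest.py; pytest finds it automatically.

For example:

__init__.py: empty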

conftest.py

import pytest


@pytest.fixture()
def db():
    print("connect successful")

    yield

    print("connect closed")

test_fix.py

import pytest


def test_passing(db):
    assert (1, 2, 3) == (1, 2, 3)

pytest-cov

pytest-cov is a plugin that measures test coverage automatically.

Installation:

pip3 install pytest-cov

Usage:

pytest --cov=module_name

or

pytest --cov=./

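module_name is the name of the module or package whose coverage you want to measure (the code under test).

.coveragerc

Running pytest --cov=./ directly measures every file. If you don't need coverage for all files, you can exclude some of them with a configuration file.

Create a .coveragerc file in the package root directory, with content such as: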

[run]
omit=
  test_api.py

Pass the configuration file when running pytest:

pytest --cov=./ --cov-config ./.coveragerc

This omits test_api.py from the coverage report.

To generate an HTML coverage report in the htmlcov directory:

pytest --cov=./ --cov-report=html test.py

To choose the directory name yourself, use --cov-report=html:dirname instead.

Customizing reports with allure

pip3 install allure-pytest

Generate the raw allure report results in ./result/:

pytest --cov=./ test_api.py --alluredir=./result/

Practice

main.py

import time
class Calculator:
    def sum(self, a, b):
        time.sleep(10)
        return a + b

test/test_main.py

import pytest

from mock import patch
from main import Calculator

@patch("main.Calculator.sum")
def test_sum(mock_sum):
    mock_sum.return_value = 9
    calc = Calculator()
    print(calc.sum(2, 4))
    print(mock_sum)

Run:

pytest test_main.py --capture=no

Output:

test_main.py 9
<MagicMock name='sum' id='139988835272576'>

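As the output shows, we mocked sum() to return 9. Calling sum(2, 4) therefore returns 9 right away, instead of the correct value 6, and the real method with its 10-second time.sleep() never runs.

A more complex example

Order class: models a purchase order for an item.

Warehouse class: the test resource, which is mocked in the tests.

OrderTest class: the test cases.

The Order class is shown below. _orderItem is the name of the item to purchase, _orderAmount is the quantity ordered, and _orderFilled is the quantity already filled (-1 until an order is filled).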
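The same test can also be written with patch.object() as a context manager, which restores Calculator.sum as soon as the with block ends. This sketch assumes Python 3's unittest.mock and defines Calculator inline, so no main module is needed:

```python
import time
from unittest.mock import patch


class Calculator:
    def sum(self, a, b):
        time.sleep(10)
        return a + b


def test_sum():
    # Calculator.sum is replaced only inside the with block.
    with patch.object(Calculator, "sum", return_value=9) as mock_sum:
        calc = Calculator()
        assert calc.sum(2, 4) == 9
    # A plain Mock on the class is not bound like a method, so self is not recorded.
    mock_sum.assert_called_once_with(2, 4)
```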

class Order(object):
    # instance properties
    _orderItem = "None"
    _orderAmount = 0
    _orderFilled = -1
    
    # Constructor
    def __init__(self, argItem, argAmount):
        print "Order:__init__"
        
        # set the order item
        if (isinstance(argItem, str)):
            if (len(argItem) > 0):
                self._orderItem = argItem
        
        # set the order amount
        if (argAmount > 0):
            self._orderAmount = argAmount
        
    # Magic methods
    def __repr__(self):
       # assemble the dictionary
        locOrder = {'item':self._orderItem, 'amount':self._orderAmount}
        return repr(locOrder)
    
    # Instance methods
    # attempt to fill the order
    def fill(self, argSrc):
        print "Order:fill_"
        
        try:
            # does the warehouse have the item in stock?
            if (argSrc is not None):
                if (argSrc.hasInventory(self._orderItem)):
                    # get the item
                    locCount = argSrc.getInventory(self._orderItem, self._orderAmount)
                
                    # update the following property
                    self._orderFilled = locCount
                else:
                    print "Inventory item not available"
            else:
                print "Warehouse not available"
        except TypeError:
            print "Invalid warehouse"
    
    # check if the order has been filled
    def isFilled(self):
        print "Order:isFilled_"
        return (self._orderAmount == self._orderFilled)

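Warehouse is an abstract class whose methods are only stubs. setup() updates its properties, hasInventory() checks whether an item is in stock, getInventory() removes a given quantity of an item, and addInventory() adds a given quantity of an item. Its contents are as follows: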

class Warehouse(object):    
    # private properties
    _houseName = None
    _houseList = None
        
    # accessors
    def warehouseName(self):
        return (self._houseName)
    
    def inventory(self):
        return (self._houseList)
    
    
    # -- INVENTORY ACTIONS
    # set up the warehouse
    def setup(self, argName, argList):
        pass
    
    # check for an inventory item
    def hasInventory(self, argItem):
        pass
    
    # retrieve an inventory item
    def getInventory(self, argItem, argCount):
        pass
        
    # add an inventory item
    def addInventory(self, argItem, argCount):
        pass

OrderTest is the test case itself. Its fooSource property holds the mock object that stands in for the Warehouse that Order needs.

import unittest
from mock import Mock, call

class OrderTest(unittest.TestCase):
    # declare the test resource
    fooSource = None
    
    # preparing to test
    def setUp(self):
        """ Setting up for the test """
        print "OrderTest:setUp_:begin"
        
        # identify the test routine
        testName = self.id().split(".")
        testName = testName[2]
        print testName
        
        # prepare and configure the test resource
        if (testName == "testA_newOrder"):
            print "OrderTest:setup_:testA_newOrder:RESERVED"
        elif (testName == "testB_nilInventory"):
            self.fooSource = Mock(spec = Warehouse, return_value = None)
        elif (testName == "testC_orderCheck"):
            self.fooSource = Mock(spec = Warehouse)
            self.fooSource.hasInventory.return_value = True
            self.fooSource.getInventory.return_value = 0
        elif (testName == "testD_orderFilled"):
            self.fooSource = Mock(spec = Warehouse)
            self.fooSource.hasInventory.return_value = True
            self.fooSource.getInventory.return_value = 10
        elif (testName == "testE_orderIncomplete"):
            self.fooSource = Mock(spec = Warehouse)
            self.fooSource.hasInventory.return_value = True
            self.fooSource.getInventory.return_value = 5
        else:
            print "UNSUPPORTED TEST ROUTINE"
    
    # ending the test
    def tearDown(self):
        """Cleaning up after the test"""
        print "OrderTest:tearDown_:begin"
        print ""
    
    # test: new order
    # objective: creating an order
    def testA_newOrder(self):
        # creating a new order
        testOrder = Order("mushrooms", 10)
        print repr(testOrder)
        
        # test for a nil object
        self.assertIsNotNone(testOrder, "Order object is a nil.")
        
        # test for a valid item name
        testName = testOrder._orderItem
        self.assertEqual(testName, "mushrooms", "Invalid item name")
        
        # test for a valid item amount
        testAmount = testOrder._orderAmount
        self.assertGreater(testAmount, 0, "Invalid item amount")
    
    # test: nil inventory
    # objective: how the order object handles a nil inventory
    def testB_nilInventory(self):
        """Test routine B"""
        # creating a new order
        testOrder = Order("mushrooms", 10)
        print repr(testOrder)
        
        # fill the order
        testSource = self.fooSource()
        testOrder.fill(testSource)
        
        # print the mocked calls
        print self.fooSource.mock_calls
        
        # check the call history
        testCalls = [call()]
        self.fooSource.assert_has_calls(testCalls)
    
    def testC_orderCheck(self):
        """Test routine C"""
        # creating a test order
        testOrder = Order("mushrooms", 10)
        print repr(testOrder)
        
        # perform the test
        testOrder.fill(self.fooSource)
        
        # perform the checks
        self.assertFalse(testOrder.isFilled())
        self.assertEqual(testOrder._orderFilled, 0)
        
        self.fooSource.hasInventory.assert_called_once_with("mushrooms")
        print self.fooSource.mock_calls
        
        # creating another order
        testOrder = Order("cabbage", 10)
        print repr(testOrder)
        
        # reconfigure the test resource
        self.fooSource.hasInventory.return_value = False
        self.fooSource.reset_mock()
        
        # perform the test
        testOrder.fill(self.fooSource)
        
        # perform the checks
        self.assertFalse(testOrder.isFilled())
        self.assertEqual(testOrder._orderFilled, -1)
        
        self.fooSource.hasInventory.assert_called_once_with("cabbage")
        print self.fooSource.mock_calls

    def testD_orderFilled(self):
        """Test routine D"""
        # creating a test order
        testOrder = Order("mushrooms", 10)
        print repr(testOrder)
        
        # perform the test
        testOrder.fill(self.fooSource)
        print testOrder.isFilled()
        
        # perform the checks
        self.assertTrue(testOrder.isFilled())
        self.assertNotEqual(testOrder._orderFilled, -1)
        
        self.fooSource.hasInventory.assert_called_once_with("mushrooms")
        self.fooSource.getInventory.assert_called_with("mushrooms", 10)
        
        testCalls = [call.hasInventory("mushrooms"), call.getInventory("mushrooms", 10)]
        self.fooSource.assert_has_calls(testCalls)

    def testE_orderIncomplete(self):
        """Test routine E"""
        # creating a test order
        testOrder = Order("mushrooms", 10)
        print repr(testOrder)
        
        # perform the test
        testOrder.fill(self.fooSource)
        print testOrder.isFilled()
        
        # perform the checks
        self.assertFalse(testOrder.isFilled())
        self.assertNotEqual(testOrder._orderFilled, testOrder._orderAmount)
        
        self.fooSource.hasInventory.assert_called_once_with("mushrooms")
        self.fooSource.getInventory.assert_called_with("mushrooms", 10)
        print self.fooSource.mock_calls
        
        testCalls = [call.hasInventory("mushrooms"), call.getInventory("mushrooms", 10)]
        self.fooSource.assert_has_calls(testCalls)

monkeypatch

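Reposted from https://zpzhou.com/archives/monkey_patch.html

A monkey patch dynamically modifies modules, classes or methods while the program is running, instead of changing the corresponding implementation in the source code.

For example, Xiaoming's favorite fruit is the apple: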

class XiaoMing(object):
    def favorite(self):
        print "apple"

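# test
xiaoming = XiaoMing()
xiaoming.favorite()
# output: apple

But one day, God decides that Xiaoming should no longer like apples. Xiaoming has already been created and cannot be rebuilt, so God gives him a monkey patch: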

class XiaoMing(object):
    def favorite(self):
        print "apple"
        
def new_favorite():
    print "banana"


# test
xiaoming = XiaoMing()
xiaoming.favorite()
# output: apple

xiaoming.favorite = new_favorite
xiaoming.favorite()
# output: banana

A more elaborate version:

class XiaoMing(object):
    def favorite(self):
        print "apple"

class God(object):
    @classmethod
    def new_xiaoming_favorite(cls):
        print "banana"

    @classmethod
    def monkey_patch(cls):
        XiaoMing.favorite = cls.new_xiaoming_favorite


# test
God.monkey_patch()

xiaoming = XiaoMing()
xiaoming.favorite()
# output: banana

How it works

namespace

Python has the concept of namespaces. A namespace is implemented as a dict that maps names to objects. Python has four main kinds of namespace:
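Patching the class changes every instance, including instances created before the patch, and the change lasts until something undoes it. In tests it is therefore common to keep the original and restore it afterwards. pytest's built-in monkeypatch fixture (monkeypatch.setattr) performs this undo automatically. Below is a Python 3 sketch of the manual version, in which favorite() returns its answer instead of printing it:

```python
class XiaoMing(object):
    def favorite(self):
        return "apple"


def new_favorite(self):
    return "banana"


xiaoming = XiaoMing()               # created before the patch
original = XiaoMing.favorite        # keep a reference so the patch can be undone
XiaoMing.favorite = new_favorite    # patch the class: existing instances see it too
try:
    print(xiaoming.favorite())      # banana
finally:
    XiaoMing.favorite = original    # undo the patch
print(xiaoming.favorite())          # apple
```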

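Type                 Description
locals               The namespace of a function; records only the objects inside the current function
enclosing function   Records the objects of an enclosing function (closures)
globals              The namespace of a module; records the module's classes, functions, etc.
builtins             Python's built-in namespace, created when the interpreter starts; holds the built-in functions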

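In Python, accessing an object (a variable, module, function, etc.) means looking its name up in these namespaces. The lookup order is LEGB: locals -> enclosing function -> globals -> builtins.

If a name is not found in any of the four namespaces, a NameError is raised.

Importing modules

Python creates a global dict, sys.modules, when it starts. You can inspect its contents: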
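A small Python 3 sketch of the LEGB lookup order (all names are illustrative):

```python
x = "global"


def outer():
    x = "enclosing"

    def inner():
        return x  # not local to inner(), so it is found in the enclosing function

    return inner()


def local_demo():
    x = "local"  # a local name shadows the global one
    return x


print(local_demo())  # local
print(outer())       # enclosing
print(x)             # global: the module-level name is untouched
print(len("abc"))    # 3: len is found in the builtins namespace

try:
    undefined_name  # not in any of the four namespaces
except NameError as exc:
    print("NameError:", exc)
```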

import sys
print(sys.modules)

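sys.modules is a dict whose keys are module names and whose values are the module objects.

When a module is imported, the following happens:

  • Python looks the module up in sys.modules; if it is already there, that module object is used directly;
  • if it is not there, the module is loaded and a new key-value pair is inserted into sys.modules;
  • the module's name is bound in the importing code's namespace (the global namespace for a top-level import), and later uses of the module are looked up there.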
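A quick Python 3 check of the caching behaviour described above, using json as an arbitrary standard-library module:

```python
import sys
import json

print("json" in sys.modules)        # True: the import registered the module
print(sys.modules["json"] is json)  # True: the name json refers to the cached module object

import json as json_again           # a second import is served from sys.modules
print(json_again is json)           # True
```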

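Implementing monkeypatch

The idea is to take the module object from sys.modules (importing the module first if needed) and replace its attributes, so that all code using the module sees the patched versions. The eventlet code below does this with setattr() on the module object rather than by replacing the sys.modules entry itself.

Take eventlet's monkeypatching of thread, socket and other modules as an example: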
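Here is a toy Python 3 sketch of that setattr() step. To stay self-contained, it builds a throwaway module named demo_clock and registers it in sys.modules. demo_clock is a hypothetical stand-in for a real module such as time or socket:

```python
import sys
import types

# Hypothetical module "demo_clock", standing in for a real module like time or socket.
demo = types.ModuleType("demo_clock")
demo.now = lambda: "real"
sys.modules["demo_clock"] = demo


def caller():
    import demo_clock  # found in sys.modules, so no file is needed
    return demo_clock.now()


def green_now():
    return "patched"


print(caller())  # real

# The same move as eventlet: fetch the module from sys.modules and setattr() on it.
orig_mod = sys.modules.get("demo_clock")
setattr(orig_mod, "now", green_now)
print(caller())  # patched
```

Only code that looks the attribute up on the module at call time sees the patch. A name copied earlier with from demo_clock import now would still point to the old function.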

def monkey_patch(**on):
    """Globally patches certain system modules to be greenthread-friendly.

    The keyword arguments afford some control over which modules are patched.
    If no keyword arguments are supplied, all possible modules are patched.
    If keywords are set to True, only the specified modules are patched.  E.g.,
    ``monkey_patch(socket=True, select=True)`` patches only the select and
    socket modules.  Most arguments patch the single module of the same name
    (os, time, select).  The exceptions are socket, which also patches the ssl
    module if present; and thread, which patches thread, threading, and Queue.

    It's safe to call monkey_patch multiple times.
    """

    # Workaround for import cycle observed as following in monotonic
    # RuntimeError: no suitable implementation for this system
    # see https://github.com/eventlet/eventlet/issues/401#issuecomment-325015989
    #
    # Make sure the hub is completely imported before any
    # monkey-patching, or we risk recursion if the process of importing
    # the hub calls into monkey-patched modules.
    eventlet.hubs.get_hub()

    accepted_args = set(('os', 'select', 'socket',
                         'thread', 'time', 'psycopg', 'MySQLdb',
                         'builtins', 'subprocess'))
    # To make sure only one of them is passed here
    assert not ('__builtin__' in on and 'builtins' in on)
    try:
        b = on.pop('__builtin__')
    except KeyError:
        pass
    else:
        on['builtins'] = b

    default_on = on.pop("all", None)

    for k in six.iterkeys(on):
        if k not in accepted_args:
            raise TypeError("monkey_patch() got an unexpected "
                            "keyword argument %r" % k)
    if default_on is None:
        default_on = not (True in on.values())
    for modname in accepted_args:
        if modname == 'MySQLdb':
            # MySQLdb is only on when explicitly patched for the moment
            on.setdefault(modname, False)
        if modname == 'builtins':
            on.setdefault(modname, False)
        on.setdefault(modname, default_on)

    if on['thread'] and not already_patched.get('thread'):
        _green_existing_locks()

    # Determine which modules need to be patched
    modules_to_patch = []
    for name, modules_function in [
        ('os', _green_os_modules),
        ('select', _green_select_modules),
        ('socket', _green_socket_modules),
        ('thread', _green_thread_modules),
        ('time', _green_time_modules),
        ('MySQLdb', _green_MySQLdb),
        ('builtins', _green_builtins),
        ('subprocess', _green_subprocess_modules),
    ]:
        if on[name] and not already_patched.get(name):
            modules_to_patch += modules_function()
            already_patched[name] = True

    if on['psycopg'] and not already_patched.get('psycopg'):
        try:
            from eventlet.support import psycopg2_patcher
            psycopg2_patcher.make_psycopg_green()
            already_patched['psycopg'] = True
        except ImportError:
            # note that if we get an importerror from trying to
            # monkeypatch psycopg, we will continually retry it
            # whenever monkey_patch is called; this should not be a
            # performance problem but it allows is_monkey_patched to
            # tell us whether or not we succeeded
            pass

    imp.acquire_lock()
    try:
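        # Iterate over the modules to patch:
        # import a module first if it has not been imported yet,
        # then use setattr() to replace the patched attributes on it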
        for name, mod in modules_to_patch:
            orig_mod = sys.modules.get(name)
            if orig_mod is None:
                orig_mod = __import__(name)
            for attr_name in mod.__patched__:
                patched_attr = getattr(mod, attr_name, None)
                if patched_attr is not None:
                    setattr(orig_mod, attr_name, patched_attr)
            deleted = getattr(mod, '__deleted__', [])
            for attr_name in deleted:
                if hasattr(orig_mod, attr_name):
                    delattr(orig_mod, attr_name)
    finally:
        imp.release_lock()

    if sys.version_info >= (3, 3):
        import importlib._bootstrap
        thread = original('_thread')
        # importlib must use real thread locks, not eventlet.Semaphore
        importlib._bootstrap._thread = thread

        # Issue #185: Since Python 3.3, threading.RLock is implemented in C and
        # so call a C function to get the thread identifier, instead of calling
        # threading.get_ident(). Force the Python implementation of RLock which
        # calls threading.get_ident() and so is compatible with eventlet.
        import threading
        threading.RLock = threading._PyRLock
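The annotated loop above boils down to plain attribute replacement on module objects. Below is a minimal, stdlib-only sketch of that idea. The names fake_time, green_fake_time and apply_patch are hypothetical; only the __patched__ / __deleted__ convention comes from the code above. The sketch also shows that a reference captured before patching (for example via from ... import) keeps pointing at the original function.

```python
import types

# Hypothetical "original" module standing in for something like time.
orig_mod = types.ModuleType("fake_time")
orig_mod.sleep = lambda seconds: "real sleep"
orig_mod.legacy = "to be removed"

# Hypothetical "green" module following the __patched__/__deleted__
# convention used by the monkey_patch() loop above.
green_mod = types.ModuleType("green_fake_time")
green_mod.sleep = lambda seconds: "green sleep"
green_mod.__patched__ = ["sleep", "not_defined_here"]
green_mod.__deleted__ = ["legacy"]


def apply_patch(orig, green):
    """Copy the listed attributes from green onto orig, then drop deleted ones."""
    for attr_name in green.__patched__:
        patched_attr = getattr(green, attr_name, None)
        if patched_attr is not None:  # names the green module lacks are skipped
            setattr(orig, attr_name, patched_attr)
    for attr_name in getattr(green, "__deleted__", []):
        if hasattr(orig, attr_name):
            delattr(orig, attr_name)


# A reference taken before patching, like "from fake_time import sleep".
early_sleep = orig_mod.sleep

apply_patch(orig_mod, green_mod)

print(orig_mod.sleep(1))             # green sleep
print(early_sleep(1))                # real sleep
print(hasattr(orig_mod, "legacy"))   # False
```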

Notes on using mock

The correct way to patch

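Suppose we want to test mymodule, which does from sysmodule import function and then calls function(). We need to mock that function. The wrong way:

@patch("sysmodule.function")

The right way:

@patch("mymodule.function")

Why:

  • When mymodule runs from sysmodule import function, it creates its own name function, which refers to the same object as sysmodule.function;
  • @patch("sysmodule.function") rebinds only the name function inside sysmodule to a MagicMock;
  • the name function in mymodule is unchanged and still refers to the original function, so the code under test never sees the mock.

If mymodule instead does import sysmodule and calls sysmodule.function(), the attribute is looked up on sysmodule at call time. In that case @patch("sysmodule.function") does work.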

If this is confusing, try it yourself. Suppose we have the following code:

main.py

from util import sum

class MyObj:
    def get_sum(self):
        return sum(2, 4)

print(dir())

# Output
['MyObj', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'sum']

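Printing dir() shows the names defined in main.py's namespace.

Although sum is defined in the util module, after from util import sum in main.py, dir() shows that sum is now a name in main.py's own namespace, bound to the same function object.

So when mocking, patch the name in the namespace of the module under test (e.g. main.sum), not at its original source (e.g. util.sum). Patching util.sum rebinds only the name in util, and main.sum still refers to the original function.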
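The experiment can be made fully self-contained with the standard-library unittest.mock, which has the same API as the mock package used earlier. This is a minimal sketch:

- The two modules are built in memory rather than read from main.py and util.py.
- main_demo, patched_at_source and patched_where_used are hypothetical names.
- Patching util.sum leaves MyObj.get_sum unaffected, while patching main_demo.sum takes effect.

```python
import sys
import types
from unittest import mock

# Build two in-memory modules so the demo needs no files on disk.
# "util" mirrors util.py; "main_demo" is a hypothetical stand-in for main.py.
util = types.ModuleType("util")
exec("def sum(a, b):\n    return a + b\n", util.__dict__)
sys.modules["util"] = util

main_demo = types.ModuleType("main_demo")
exec(
    "from util import sum\n"
    "\n"
    "class MyObj:\n"
    "    def get_sum(self):\n"
    "        return sum(2, 4)\n",
    main_demo.__dict__,
)
sys.modules["main_demo"] = main_demo


def patched_at_source():
    # Rebinds util.sum only; main_demo.sum still refers to the original.
    with mock.patch("util.sum", return_value=100):
        return main_demo.MyObj().get_sum()


def patched_where_used():
    # Rebinds the name that get_sum actually looks up at call time.
    with mock.patch("main_demo.sum", return_value=100):
        return main_demo.MyObj().get_sum()


print(patched_at_source())   # 6
print(patched_where_used())  # 100
```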

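Mocking multiple return values

If the function under test returns several values at once, make the mock return a tuple. For example:

Original code:

a, b, c = func()

To mock func's output, write:

mock_func.return_value = (a, b, c)

Setting side_effect = (a, b, c) instead would make three successive calls return a, then b, then c. A single call would then return only a.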
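The difference between return_value and side_effect can be checked directly. Below is a minimal sketch using unittest.mock, with literal numbers standing in for a, b and c.

```python
from unittest import mock

# return_value: every call returns the whole tuple, so unpacking works.
mock_func = mock.Mock(return_value=(1, 2, 3))
a, b, c = mock_func()

# side_effect with an iterable: each call returns the next item instead.
mock_seq = mock.Mock(side_effect=(1, 2, 3))
first, second, third = mock_seq(), mock_seq(), mock_seq()

print(a, b, c)               # 1 2 3
print(first, second, third)  # 1 2 3

# Once the iterable is exhausted, a further call raises StopIteration.
exhausted = False
try:
    mock_seq()
except StopIteration:
    exhausted = True
```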

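Mock the lowest-level method

Suppose you want to check whether log.error() is called. Don't mock log; mock error itself:

import logging
from unittest.mock import Mock
log = logging.getLogger('a')
log.error = Mock()
log.error('xxx')  # stands in for the code under test calling log.error
log.error.assert_called_with('xxx')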
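Assigning log.error = Mock() replaces the method on the shared logger for the rest of the process, because logging.getLogger('a') always returns the same object. A scoped alternative uses unittest.mock.patch.object, which restores the real method when the with block exits. In the sketch below, do_work and check_error_logged are hypothetical stand-ins for the code under test and the test itself.

```python
import logging
from unittest import mock

log = logging.getLogger("a")


def do_work():
    # Hypothetical code under test that logs an error.
    log.error("xxx")


def check_error_logged():
    # patch.object swaps log.error for a MagicMock only inside the with block.
    with mock.patch.object(log, "error") as mock_error:
        do_work()
    mock_error.assert_called_once_with("xxx")
    return mock_error.call_count


count = check_error_logged()
# Outside the with block, log.error is the real bound method again.
restored = not isinstance(log.error, mock.Mock)
```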

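Asserting that no exception is raised

pytest can assert that an exception is raised, e.g.:

import pytest
with pytest.raises(Exception, match=r'xxx'):
    function()

However, it provides no direct way to assert that no exception is raised. If you need this, you can write: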

try:
    function()
except Exception:
    pytest.fail("xxx")
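Any exception that escapes a pytest test already fails it. The try/except form above mainly turns the error into an explicit failure with your own message. The same pattern can be wrapped in a reusable context manager.

Below is a minimal sketch built around a hypothetical helper named assert_no_exception. It raises AssertionError, which pytest reports as a test failure, instead of calling pytest.fail, so the sketch runs without pytest installed.

```python
import contextlib


@contextlib.contextmanager
def assert_no_exception(message="unexpected exception"):
    """Turn any Exception raised inside the with block into an AssertionError."""
    try:
        yield
    except Exception as exc:
        raise AssertionError(f"{message}: {exc!r}") from exc


# Passes silently: int("42") does not raise.
with assert_no_exception():
    value = int("42")

# Fails: the ValueError is converted into an AssertionError with our message.
try:
    with assert_no_exception("parsing failed"):
        int("not a number")
except AssertionError as err:
    failure_message = str(err)

print(value)            # 42
print(failure_message)
```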
